import torch
import collections
import math
from typing import Optional
from itertools import repeat

from .dlrt_module import DLRTModule
# Make float32 the process-wide default dtype for newly created floating-point tensors.
# NOTE(review): this is a global, import-time side effect — it applies to every
# tensor created afterwards anywhere in the process, not only in this module.
torch.set_default_dtype(torch.float32)

# Compatibility shim: re-attach the ABC aliases that were removed from the
# top-level ``collections`` namespace in Python 3.10. Presumably this keeps older
# code that still references e.g. ``collections.Iterable`` working — confirm
# which dependency needs it. This mutates the shared ``collections`` module.
collections.Iterable = collections.abc.Iterable
collections.Mapping = collections.abc.Mapping
collections.MutableSet = collections.abc.MutableSet
collections.MutableMapping = collections.abc.MutableMapping


def _ntuple(n, name="parse"):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    parse.__name__ = name
    return parse


# Converters for 1-4 dimensional arguments: a scalar is broadcast to a tuple of
# that length, and an iterable is converted to a tuple as-is. Each converter's
# __name__ matches the module-level name it is bound to.
_single, _pair, _triple, _quadruple = (
    _ntuple(1, "_single"),
    _ntuple(2, "_pair"),
    _ntuple(3, "_triple"),
    _ntuple(4, "_quadruple"),
)


class _ConvNd(DLRTModule):
    """Common base for the DLRT (low-rank) N-dimensional convolution layers.

    Adapted from torch's ``_ConvNd``
    (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py).

    The class does three things:
    - validates and stores the standard convolution hyper-parameters;
    - precomputes the explicit padding layout;
    - registers ``bias``.

    No weight tensor is created here. Presumably the concrete subclasses build
    the low-rank factors — TODO confirm against the subclasses.
    """

    __constants__ = [
        "stride",
        "padding",
        "dilation",
        "groups",
        "padding_mode",
        "output_padding",
        "in_channels",
        "out_channels",
        "kernel_size",
    ]
    __annotations__ = {"bias": Optional[torch.Tensor]}

    # Instance attributes assigned in ``__init__`` (expected types; documentation only):
    #   in_channels: int
    #   out_channels: int
    #   kernel_size: tuple[int, ...]
    #   stride: tuple[int, ...]
    #   padding: str | tuple[int, ...]
    #   dilation: tuple[int, ...]
    #   transposed: bool
    #   output_padding: tuple[int, ...]
    #   groups: int
    #   padding_mode: str
    #   tau, fixed_rank, kernel_size_number, basic_number_weights
    #   _reversed_padding_repeated_twice: list[int]
    #   bias: torch.nn.Parameter (never None as constructed here)
    # ==== low_rank =================================================
    # Declared for subclasses / DLRTModule; not assigned in this class.
    low_rank: int
    rmax: int
    convert_from_weights: torch.Tensor

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed: bool,
            output_padding,
            groups: int,
            bias: bool,
            padding_mode: str,
            device=None,
            dtype=None,
            # ====================== DLRT params =========================================
            low_rank_percent: None = None,
            fixed_rank: bool = False,
            tau=0.1
    ) -> None:
        """Validate and store the convolution configuration.

        Args:
            in_channels: Number of input channels; must be divisible by ``groups``.
            out_channels: Number of output channels; must be divisible by ``groups``.
            kernel_size: Per-dimension kernel extents (an iterable of ints).
            stride: Per-dimension strides. It is iterated when ``padding == "same"``.
            padding: ``"same"``, ``"valid"``, or per-dimension padding amounts.
            dilation: Per-dimension dilations. It is iterated when ``padding == "same"``.
            transposed: Stored as-is; not otherwise used in this class.
            output_padding: Stored as-is; not otherwise used in this class.
            groups: Number of channel groups; must be a positive integer.
            bias: If true, a trainable ``bias`` of length ``out_channels`` is
                created with ``torch.empty`` (uninitialised here). Otherwise a
                frozen all-zero ``bias`` is registered instead of ``None``.
            padding_mode: One of ``"zeros"``, ``"reflect"``, ``"replicate"``,
                ``"circular"``.
            device: Forwarded to the ``bias`` tensor factory.
            dtype: Forwarded to the ``bias`` tensor factory.
            low_rank_percent: Accepted for interface compatibility; not used
                by this base class.
            fixed_rank: Stored on ``self.fixed_rank``.
            tau: Forwarded to ``DLRTModule.__init__`` and stored on ``self.tau``.

        Raises:
            ValueError: If ``groups`` is not positive, either channel count is
                not divisible by ``groups``, the padding string is unknown,
                ``padding="same"`` is combined with a stride other than 1, or
                ``padding_mode`` is unknown.
        """
        # from torch =================================================================
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(tau=tau)
        # Check this first: groups == 0 would otherwise surface as a
        # ZeroDivisionError below. A negative value could pass the divisibility
        # checks and produce negative per-group channel counts.
        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        valid_padding_strings = {"same", "valid"}
        if isinstance(padding, str):
            if padding not in valid_padding_strings:
                raise ValueError(
                    "Invalid padding string {!r}, should be one of {}".format(
                        padding,
                        valid_padding_strings,
                    ),
                )
            if padding == "same" and any(s != 1 for s in stride):
                raise ValueError("padding='same' is not supported for strided convolutions")

        valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
        if padding_mode not in valid_padding_modes:
            raise ValueError(
                "padding_mode must be one of {}, but got padding_mode='{}'".format(
                    valid_padding_modes,
                    padding_mode,
                ),
            )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode
        self.tau = tau

        # Build explicit [left, right] padding for every spatial dimension,
        # ordered last-dimension-first (the layout torch.nn.functional.pad uses).
        if isinstance(self.padding, str):
            # "valid" means no padding at all.
            self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
            if padding == "same":
                # Walk dimensions forward while filling slots from the back.
                for d, k, i in zip(
                        dilation,
                        kernel_size,
                        range(len(kernel_size) - 1, -1, -1),
                ):
                    total_padding = d * (k - 1)
                    left_pad = total_padding // 2
                    # Any odd leftover goes on the right side.
                    self._reversed_padding_repeated_twice[2 * i] = left_pad
                    self._reversed_padding_repeated_twice[2 * i + 1] = total_padding - left_pad
        else:
            self._reversed_padding_repeated_twice = torch.nn.modules.utils._reverse_repeat_tuple(self.padding, 2)

        # new changes =================================================================
        self.fixed_rank = fixed_rank
        kernel_size_number = math.prod(self.kernel_size)
        self.kernel_size_number = kernel_size_number

        # NOTE(review): in_channels * (out_channels // groups + prod(kernel_size));
        # the meaning of this count is defined by its users elsewhere — confirm.
        self.basic_number_weights = in_channels * (out_channels // groups + kernel_size_number)

        if bias:
            # Uninitialised storage; presumably initialised by a subclass — confirm.
            self.bias = torch.nn.Parameter(torch.empty(out_channels, **factory_kwargs))
        else:
            # A disabled bias is still registered, as a frozen zero vector rather
            # than None, so ``self.bias`` is always a tensor.
            self.bias = torch.nn.Parameter(
                torch.zeros(out_channels, requires_grad=False, **factory_kwargs),
                requires_grad=False,
            )

    def extra_repr(self):
        """Return the configuration summary shown inside ``repr(module)``."""
        s = (
            f"{self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, "
            f"stride={self.stride}"
        )
        if self.padding != (0,) * len(self.padding):
            s += f", padding={self.padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += f", dilation={self.dilation}"
        if self.output_padding != (0,) * len(self.output_padding):
            s += f", output_padding={self.output_padding}"
        if self.groups != 1:
            s += f", groups={self.groups}"
        # NOTE(review): __init__ always registers a tensor for ``bias`` (zeros when
        # disabled), so this branch only fires if something else sets it to None.
        if self.bias is None:
            s += ", bias=False"
        if self.padding_mode != "zeros":
            s += f", padding_mode={self.padding_mode}"
        # ``rank`` is not assigned in this class (presumably set by a subclass or
        # DLRTModule). Falling back to None keeps repr() from raising if it is missing.
        s += f", low_rank={getattr(self, 'rank', None)}"
        return s

    def __setstate__(self, state):
        """Restore pickled state, defaulting ``padding_mode`` to ``"zeros"`` when it is absent."""
        super().__setstate__(state)
        if not hasattr(self, "padding_mode"):
            self.padding_mode = "zeros"
